Building Your First t-SNE Embedding

t-SNE has four important hyperparameters that can drastically change the resulting embedding:

- perplexity, which balances attention between local & global structure in the data
- theta, which trades accuracy for speed (a value of 0 performs exact t-SNE)
- eta, the learning rate, which controls how far each case moves at each iteration
- max_iter, the maximum number of iterations

Our approach to tuning hyperparameters thus far has been to allow an automated tuning process to choose the best combination for us, through either grid search or random search. But due to its computational cost, most people will run t-SNE with its default hyperparameter values & change them if the embedding doesn’t look sensible. If this sounds very subjective, that’s because it is; but people are usually able to identify visually whether t-SNE is pulling apart clusters of observations nicely.

To give a visual aid for how each of these hyperparameters affects the final embedding, we’ll run t-SNE on our Swiss banknote data using a grid of hyperparameter values.
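Such a grid can be run in a few lines of code. Here is a minimal sketch (the grid values & the `swissData`, `grid`, & `embeddings` names are illustrative choices of ours, & the plotting step is omitted):

```r
library(Rtsne)
library(purrr)

data(banknote, package = 'mclust')
swissData <- banknote[, -1]  # drop the categorical Status column

# Illustrative grid of perplexity & theta values (eta & max_iter stay at defaults)
grid <- expand.grid(perplexity = c(2, 10, 30, 50), theta = c(0, 0.5, 1))

set.seed(123)  # t-SNE is stochastic, so fix the seed for reproducibility

# Run t-SNE once per combination; each element of embeddings is a 200 x 2 matrix
embeddings <- pmap(grid, function(perplexity, theta) {
  Rtsne(swissData, perplexity = perplexity, theta = theta)$Y
})
```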

The figure above shows the final embeddings for different combinations of theta (rows) & perplexity (columns), using the default values of eta & max_iter. Notice that the clusters become tighter at larger values of perplexity & are lost at very low values. Also notice that, for reasonable values of perplexity, the clusters are best resolved when theta is set to 0 (exact t-SNE).

The figure below shows the final embeddings for different combinations of max_iter (rows) & eta (columns). The effect here is a little more subtle, but smaller values of eta need a larger number of iterations in order to converge (because the cases move in smaller steps at each iteration). For example, with an eta of 100, 1,000 iterations is sufficient to separate the clusters; but with an eta of 1, the clusters remain poorly resolved after 1,000 iterations.
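The eta / max_iter interaction can be reproduced with two runs that differ only in learning rate; a minimal sketch, assuming the Rtsne & mclust packages are installed (the object names are ours):

```r
library(Rtsne)

data(banknote, package = 'mclust')
set.seed(123)  # fix the seed so the two runs are comparable

# A large learning rate separates the clusters within 1,000 iterations...
fast <- Rtsne(banknote[, -1], perplexity = 30, eta = 100, max_iter = 1000)

# ...but a tiny one moves the cases in such small steps that, after the
# same 1,000 iterations, the clusters remain poorly resolved
slow <- Rtsne(banknote[, -1], perplexity = 30, eta = 1, max_iter = 1000)
```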

Now that we have a better feel for how t-SNE’s hyperparameters affect its performance, let’s run t-SNE on our Swiss banknote data set. Just like for PCA, we first select all the columns except the categorical variable (t-SNE cannot handle categorical variables either) & pipe this data into the Rtsne() function. We manually set the perplexity, theta, & max_iter hyperparameters & set the argument verbose = TRUE so the algorithm prints a running commentary, including the KL divergence at each iteration.

library(tibble)
library(dplyr)
library(Rtsne)

data(banknote, package = 'mclust')
swissTib <- as_tibble(banknote)

swissTsne <- select(swissTib, -Status) %>%
  Rtsne(perplexity = 30, theta = 0, max_iter = 5000, verbose = TRUE)
## Performing PCA
## Read the 200 x 6 data matrix successfully!
## Using no_dims = 2, perplexity = 30.000000, and theta = 0.000000
## Computing input similarities...
## Symmetrizing...
## Done in 0.01 seconds!
## Learning embedding...
## Iteration 50: error is 48.730003 (50 iterations in 0.02 seconds)
## Iteration 100: error is 47.504519 (50 iterations in 0.02 seconds)
## Iteration 150: error is 46.447822 (50 iterations in 0.02 seconds)
## Iteration 200: error is 46.140500 (50 iterations in 0.02 seconds)
## Iteration 250: error is 45.660773 (50 iterations in 0.02 seconds)
## Iteration 300: error is 0.375921 (50 iterations in 0.02 seconds)
## Iteration 350: error is 0.305893 (50 iterations in 0.02 seconds)
## Iteration 400: error is 0.277276 (50 iterations in 0.02 seconds)
## Iteration 450: error is 0.272883 (50 iterations in 0.02 seconds)
## Iteration 500: error is 0.270994 (50 iterations in 0.02 seconds)
## Iteration 550: error is 0.270189 (50 iterations in 0.02 seconds)
## Iteration 600: error is 0.269744 (50 iterations in 0.02 seconds)
## ...
## Iteration 5000: error is 0.268906 (50 iterations in 0.02 seconds)
## Fitting performed in 1.75 seconds.
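Before plotting, it can help to inspect what Rtsne() returned. The embedding itself lives in the Y element of the result: a matrix with one row per banknote & one column per output dimension (this sketch continues from the listing above):

```r
# The embedding is stored in the Y component of the Rtsne() output
dim(swissTsne$Y)      # 200 rows (banknotes) x 2 columns (t-SNE dimensions)
head(swissTsne$Y, 3)  # coordinates of the first three banknotes
```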

Plotting the Result of t-SNE

Let’s plot the two t-SNE dimensions against each other to see how well they separate the genuine & counterfeit banknotes. Because we can’t interpret the axes in terms of how much each original variable correlates with them, it’s common to colour t-SNE plots by the values of each of the original variables, to help identify which clusters have higher & lower values. To do this, we first use the mutate_if() function to center the numeric variables in our original data set (by setting .funs = scale & .predicate = is.numeric). We include scale = FALSE so that scale() only centers the variables, without dividing by their standard deviations. We center the variables because we’re going to shade the plots by their values, & we don’t want variables with larger values dominating the colour scale.

Next, we mutate two new columns containing the t-SNE axis values for each case. Finally, we gather the data so that we can facet by each of the original variables. We then plot the data, mapping each original variable’s value to the colour aesthetic & the status of each banknote (genuine vs. counterfeit) to the shape aesthetic, & facet by the original variables. We add a custom colour gradient to make the colour scale more readable in print.

library(tidyr)
library(ggplot2)
library(plotly)

swissTibTsne <- swissTib %>%
  mutate_if(.funs = scale, .predicate = is.numeric, scale = FALSE) %>%
  mutate(tSNE1 = swissTsne$Y[, 1], tSNE2 = swissTsne$Y[, 2]) %>%
  gather(key = 'Variable', value = 'Value', c(-tSNE1, -tSNE2, -Status))
## Warning: attributes are not identical across measure variables;
## they will be dropped
ggplotly(
  ggplot(swissTibTsne, aes(tSNE1, tSNE2, col = Value, shape = Status)) +
    facet_wrap(~ Variable) +
    geom_point(size = 3) +
    scale_colour_gradient(low = 'dark blue', high = 'cyan') +
    theme_bw()
)